{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import commpy as cp\n",
    "import scipy.signal as sig\n",
    "import scipy.linalg as la\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "from multiprocessing import Pool\n",
    "\n",
    "import traceback\n",
    "import warnings\n",
    "import sys\n",
    "\n",
    "%matplotlib inline\n",
    "DEFAULT_SEED = 100\n",
    "np.random.seed(DEFAULT_SEED)#set the random generator seed to default"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def modulate(data, mod_scheme='BPSK', demod=False):\n",
    "    \"\"\"  1. Modulates (or demodulates) data according to the modulation scheme\n",
    "\n",
    "    data       : array_like of bits (modulation) or received samples (demodulation);\n",
    "                 it is flattened to 1-D before processing.\n",
    "    mod_scheme : 'BPSK' or 'QPSK'.\n",
    "                 BPSK maps bit 0 -> -1 and every other value -> +1; when demodulating,\n",
    "                 a sample < 0 becomes bit 0 and anything else becomes bit 1.\n",
    "                 QPSK delegates to commpy's QAMModem(4) (hard decisions when demodulating).\n",
    "    demod      : False to modulate, True to demodulate.\n",
    "\n",
    "    Returns a 1-D numpy array of symbols (modulation) or bits (demodulation).\n",
    "    Raises ValueError when mod_scheme is not a supported scheme.\n",
    "    \"\"\"\n",
    "    mod_schemes = ['BPSK', 'QPSK']\n",
    "    data = np.asarray(data).flatten()  # also accepts plain lists\n",
    "    if mod_scheme not in mod_schemes:\n",
    "        raise ValueError('Unknown modulation scheme, please choose from: '+ ' '.join(mod_schemes))\n",
    "    if mod_scheme == 'QPSK':\n",
    "        modulator = cp.modulation.QAMModem(4)\n",
    "        if demod:\n",
    "            return modulator.demodulate(data, \"hard\")\n",
    "        return modulator.modulate(data)\n",
    "    # BPSK: np.where maps the whole array at once and, unlike np.vectorize without\n",
    "    # otypes, does not fail on an empty input.\n",
    "    if demod:\n",
    "        return np.where(data < 0, 0, 1)\n",
    "    return np.where(data == 0, -1, 1)\n",
    "\n",
    "def apply_channel(signal, channel_function):\n",
    "    \"\"\"  2. Passes signal through the channel: full linear convolution with channel_function \"\"\"\n",
    "    # mode='full' keeps every overlap between the two sequences, tails included\n",
    "    return sig.convolve(signal, channel_function, mode='full')\n",
    "\n",
    "def add_awgn_noise(signal, SNR_dB):\n",
    "    \"\"\"  3. Adds AWGN noise vector to signal\n",
    "            to generate a resulting signal vector y of specified SNR in dB\n",
    "\n",
    "    signal : array_like of real or complex samples (flattened to 1-D).\n",
    "    SNR_dB : target signal-to-noise ratio in dB, relative to the measured\n",
    "             average energy per sample of signal.\n",
    "\n",
    "    Returns a 1-D numpy array signal + noise. The noise is drawn from the global\n",
    "    np.random state. An empty signal is returned as an empty array.\n",
    "    \"\"\"\n",
    "    signal = np.asarray(signal).flatten()\n",
    "    L = len(signal)\n",
    "    if L == 0:\n",
    "        return signal  # nothing to corrupt; also avoids the 0/0 energy estimate below\n",
    "    SNR = 10**(SNR_dB/10.0)  # SNR to linear scale\n",
    "    Esym = np.sum(np.square(np.abs(signal)))/L  # measured average symbol energy\n",
    "    N0 = Esym/SNR  # noise power needed to hit the requested SNR\n",
    "    # np.iscomplexobj inspects the array dtype, so complex64 input is recognised too;\n",
    "    # isinstance(signal[0], complex) is only True for complex128 scalars.\n",
    "    if np.iscomplexobj(signal):\n",
    "        noiseSigma = np.sqrt(N0/2.0)  # noise power is split between real and imaginary parts\n",
    "        n = noiseSigma*(np.random.randn(L) + 1j*np.random.randn(L))\n",
    "    else:\n",
    "        noiseSigma = np.sqrt(N0)  # all of the noise power goes on the real axis\n",
    "        n = noiseSigma*np.random.randn(L)\n",
    "    return signal + n  # received signal\n",
    "\n",
    "def num_bit_errs(in_bits, out_bits):\n",
    "    \"\"\"Counts the positions, over the length of in_bits, where out_bits differs from in_bits.\n",
    "\n",
    "    out_bits may be longer than in_bits; entries past len(in_bits) are never read.\n",
    "    \"\"\"\n",
    "    mismatches = [1 for idx in range(len(in_bits)) if in_bits[idx] != out_bits[idx]]\n",
    "    return len(mismatches)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class Baseline:\n",
    "    \"\"\"Reference 'equalizer' that learns nothing and hands its input back unchanged.\"\"\"\n",
    "\n",
    "    def __init__(self, params={}):\n",
    "        \"\"\"Accepts a params dict and ignores it; there is no state to set up.\"\"\"\n",
    "\n",
    "    def train(self, x, y, params={}):\n",
    "        \"\"\"Does nothing with the training pair and returns None.\"\"\"\n",
    "        return None\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Returns x itself, untouched.\"\"\"\n",
    "        return x\n",
    "    \n",
    "class LeastMeanSquares:\n",
    "    \"\"\"Adaptive linear equalizer whose taps are updated sample by sample with the LMS rule.\n",
    "\n",
    "    init_params keys:\n",
    "        'equalizer_order' : odd int >= 3, number of filter taps (required)\n",
    "        'random_starts'   : True -> Gaussian random initial taps, False -> zeros (default True)\n",
    "        'learning_rate'   : LMS step size mu (default 0.01)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, init_params={'equalizer_order':None,'random_starts':True, 'learning_rate':0.01}):\n",
    "        \"\"\"Validates the order and stores the hyper-parameters.\n",
    "\n",
    "        Raises ValueError when 'equalizer_order' is missing, even, or smaller than 3.\n",
    "        \"\"\"\n",
    "        # .get() so a partial dict falls back to the documented defaults instead of a KeyError\n",
    "        order = init_params.get('equalizer_order')\n",
    "        if (not order):\n",
    "            raise ValueError(\"LeastMeanSquares: init_params['equalizer_order'] is missing\")\n",
    "        if (order % 2 == 0):\n",
    "            raise ValueError(\"LeastMeanSquares: init_params['equalizer_order'] must be odd\")\n",
    "        if (order < 3):\n",
    "            raise ValueError(\"LeastMeanSquares: init_params['equalizer_order'] must be at least 3\")\n",
    "        self.order = order\n",
    "        self.h = None  # filter taps; created on first use so their dtype can follow the input\n",
    "        self.random_starts = init_params.get('random_starts', True)\n",
    "        self.L = (self.order-1)//2  # number of taps on each side of the centre tap\n",
    "        self.mu = init_params.get('learning_rate', 0.01)\n",
    "\n",
    "    def _init_taps(self, x):\n",
    "        \"\"\"Creates self.h: complex taps when x is complex, real taps otherwise.\"\"\"\n",
    "        is_complex = np.iscomplexobj(x)\n",
    "        if (self.random_starts):\n",
    "            if is_complex:\n",
    "                self.h = np.random.randn(self.order) + 1j*np.random.randn(self.order)\n",
    "            else:\n",
    "                self.h = np.random.randn(self.order)\n",
    "        else:\n",
    "            # np.complex128 replaces the np.complex_ alias, which NumPy 2.0 removed\n",
    "            self.h = np.zeros(self.order, dtype=np.complex128 if is_complex else np.float32)\n",
    "\n",
    "    def train(self, x, y, train_params={}):\n",
    "        \"\"\"Runs one LMS pass over the training data, updating self.h in place.\n",
    "\n",
    "        x : received samples; y : reference symbols, with y[i] aligned to x[i].\n",
    "        x must provide at least len(y) samples. train_params is accepted but unused.\n",
    "        \"\"\"\n",
    "        x = np.asarray(x)\n",
    "        if (self.h is None):\n",
    "            self._init_taps(x)\n",
    "        # zero-pad both ends so a full window exists around the first and last samples\n",
    "        x = np.pad(x, self.L, 'constant', constant_values=(0j if np.iscomplexobj(x) else 0.0))\n",
    "        for i in range(len(y)):\n",
    "            # time-reversed window, so dot(r, h) equals the convolution used in predict()\n",
    "            r = np.flip(x[i: i + self.order],0)\n",
    "            error = y[i] - np.dot(r, self.h)\n",
    "            self.h = self.h + self.mu * error * r.conj()\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Filters x with the current taps and drops the first self.L output samples.\n",
    "\n",
    "        If the equalizer was never trained, the taps are initialised first\n",
    "        (random or zero, according to random_starts).\n",
    "        \"\"\"\n",
    "        x = np.asarray(x)\n",
    "        if (self.h is None):\n",
    "            self._init_taps(x)\n",
    "        return sig.convolve(x, self.h , mode=\"full\")[self.L:]\n",
    "    \n",
    "class LeastSquares:\n",
    "    \"\"\"Linear equalizer fitted by least squares over the whole training block.\n",
    "\n",
    "    init_params keys:\n",
    "        'equalizer_order' : odd int >= 3, number of filter taps (required)\n",
    "        'random_starts'   : True -> Gaussian random initial taps, False -> zeros (default True)\n",
    "        'learning_rate'   : gradient-descent step size mu (default 0.01)\n",
    "        'steps'           : number of batch gradient steps; 0 selects the closed-form\n",
    "                            solve instead (default 1000)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, init_params={'equalizer_order':None,'random_starts':True, 'learning_rate':0.01, 'steps':1000}):\n",
    "        \"\"\"Validates the order and stores the hyper-parameters.\n",
    "\n",
    "        Raises ValueError when 'equalizer_order' is missing, even, or smaller than 3.\n",
    "        \"\"\"\n",
    "        # .get() so a partial dict falls back to the documented defaults instead of a KeyError\n",
    "        order = init_params.get('equalizer_order')\n",
    "        if (not order):\n",
    "            raise ValueError(\"LeastSquares: init_params['equalizer_order'] is missing\")\n",
    "        if (order % 2 == 0):\n",
    "            raise ValueError(\"LeastSquares: init_params['equalizer_order'] must be odd\")\n",
    "        if (order < 3):\n",
    "            raise ValueError(\"LeastSquares: init_params['equalizer_order'] must be at least 3\")\n",
    "        self.order = order\n",
    "        self.h = None  # filter taps; created on first use so their dtype can follow the input\n",
    "        self.random_starts = init_params.get('random_starts', True)\n",
    "        self.L = (self.order-1)//2  # number of taps on each side of the centre tap\n",
    "        self.steps = init_params.get('steps', 1000)\n",
    "        self.mu = init_params.get('learning_rate', 0.01)\n",
    "\n",
    "    def _init_taps(self, x):\n",
    "        \"\"\"Creates self.h: complex taps when x is complex, real taps otherwise.\"\"\"\n",
    "        is_complex = np.iscomplexobj(x)\n",
    "        if (self.random_starts):\n",
    "            if is_complex:\n",
    "                self.h = np.random.randn(self.order) + 1j*np.random.randn(self.order)\n",
    "            else:\n",
    "                self.h = np.random.randn(self.order)\n",
    "        else:\n",
    "            # np.complex128 replaces the np.complex_ alias, which NumPy 2.0 removed\n",
    "            self.h = np.zeros(self.order, dtype=np.complex128 if is_complex else np.float32)\n",
    "\n",
    "    def _design_matrix(self, x, n_rows):\n",
    "        \"\"\"Stacks the time-reversed, order-long windows of zero-padded x, one row per symbol.\"\"\"\n",
    "        x = np.asarray(x)\n",
    "        padded = np.pad(x, self.L, 'constant', constant_values=(0j if np.iscomplexobj(x) else 0.0))\n",
    "        return np.array([np.flip(padded[i: i+self.order],0) for i in range(n_rows)])\n",
    "\n",
    "    def train(self, x, y, params={}):\n",
    "        \"\"\"Fits self.h so that the filtered x approximates y in the least-squares sense.\n",
    "\n",
    "        x : received samples; y : reference symbols, with y[i] aligned to x[i].\n",
    "        With steps == 0 the closed-form solution is used; otherwise `steps` batch\n",
    "        gradient-descent updates are applied, stopping early if the gradient stops\n",
    "        being finite. params is accepted but unused.\n",
    "        \"\"\"\n",
    "        steps = self.steps\n",
    "        if (steps == 0):\n",
    "            self.train_closed_form(x, y)\n",
    "            return\n",
    "        x = np.asarray(x)\n",
    "        if (self.h is None):\n",
    "            self._init_taps(x)\n",
    "        A = self._design_matrix(x, len(y))\n",
    "        # The gradient of ||A h - y||^2 needs the conjugate transpose; for real data this is just A.T\n",
    "        A_H = A.conj().T\n",
    "        while steps > 0:\n",
    "            grad_update = np.dot(A_H, np.dot(A, self.h) - y)\n",
    "            # stop on any inf/nan entry so a diverging run keeps the last finite taps\n",
    "            if not np.all(np.isfinite(grad_update)):\n",
    "                break\n",
    "            self.h = self.h - (self.mu / len(y)) * grad_update\n",
    "            steps -= 1\n",
    "\n",
    "    def train_closed_form(self, x, y):\n",
    "        \"\"\"Solves min_h ||A h - y|| directly with np.linalg.lstsq and stores the result in self.h.\"\"\"\n",
    "        A = self._design_matrix(x, len(y))\n",
    "        h,_,_,_ = np.linalg.lstsq(A, y,rcond=-1)\n",
    "        self.h = h\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Filters x with the current taps and drops the first self.L output samples.\n",
    "\n",
    "        If the equalizer was never trained, the taps are initialised first\n",
    "        (random or zero, according to random_starts).\n",
    "        \"\"\"\n",
    "        x = np.asarray(x)\n",
    "        if (self.h is None):\n",
    "            self._init_taps(x)\n",
    "        return sig.convolve(x, self.h , mode=\"full\")[self.L:]\n",
    "    \n",
    "class ZeroForcing:\n",
    "    \"\"\"Zero-forcing equalizer: divides a known channel response out in the frequency domain.\"\"\"\n",
    "\n",
    "    def __init__(self, init_params={}):\n",
    "        \"\"\"init_params is accepted but unused; the channel is supplied through train().\"\"\"\n",
    "        self.channel = None  # channel taps, set by train()\n",
    "\n",
    "    def train(self, x, y, channel_taps):\n",
    "        \"\"\"Stores channel_taps as the channel to invert; x and y are not used.\"\"\"\n",
    "        self.channel = channel_taps\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Equalizes x by dividing its FFT by the FFT of the stored channel.\n",
    "\n",
    "        Both transforms use length len(x). The first len(x) - len(channel) + 1\n",
    "        samples of the inverse transform are returned: complex when x is complex,\n",
    "        otherwise only the real part.\n",
    "\n",
    "        Raises ValueError if train() has not supplied a channel yet.\n",
    "        \"\"\"\n",
    "        if (self.channel is None):\n",
    "            raise ValueError(\"ZeroForcing.predict: must train with channel first\")\n",
    "        n = len(x)\n",
    "        freq_domain = np.fft.fft(x, n)/np.fft.fft(self.channel, n)\n",
    "        equalized = np.fft.ifft(freq_domain)[0:n - len(self.channel) + 1]\n",
    "        # np.iscomplexobj checks the dtype, so complex64 input keeps its imaginary part\n",
    "        if np.iscomplexobj(x):\n",
    "            return equalized\n",
    "        return np.real(equalized)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Baseline bit error rate:  1.0\n",
      "Least Mean Squares bit error rate:  0.0\n",
      "Least Squares bit error rate:  0.0\n",
      "Zero Forcing bit error rate:  0.0\n",
      "[-0.99868121 -0.05134035]\n"
     ]
    }
   ],
   "source": [
    "# Compare the equalizers on one random channel: train on a known preamble, then measure BER on fresh data.\n",
    "LMS = LeastMeanSquares(init_params={'equalizer_order':5,'random_starts':True, 'learning_rate':0.01})\n",
    "LS = LeastSquares(init_params={'equalizer_order':5,'random_starts':True, 'learning_rate':0.01, 'steps':1000})\n",
    "ZF = ZeroForcing()\n",
    "\n",
    "channel_length = 2\n",
    "preamble_length = 400\n",
    "data_length = 100\n",
    "\n",
    "modulation_scheme = 'BPSK'\n",
    "\n",
    "snr = 10\n",
    "\n",
    "# random real channel taps, normalised to unit energy\n",
    "channel_function = np.random.randn(channel_length) \n",
    "channel_function = channel_function / np.linalg.norm(channel_function)\n",
    "\n",
    "# generate training data\n",
    "preamble_bits = np.random.randint(0,2, preamble_length) \n",
    "train_symbols = modulate(preamble_bits, modulation_scheme)\n",
    "train_signal = add_awgn_noise(apply_channel(train_symbols, channel_function), snr)\n",
    "# generate testing data\n",
    "test_bits = np.random.randint(0,2, data_length)\n",
    "test_symbols = modulate(test_bits, modulation_scheme)\n",
    "test_signal = add_awgn_noise(apply_channel(test_symbols, channel_function), snr)\n",
    "# calculating baseline error (no training)\n",
    "baseline_bits = modulate(test_signal, modulation_scheme, True)\n",
    "\n",
    "\n",
    "LMS.train(train_signal, train_symbols)\n",
    "lms_test_output = LMS.predict(test_signal)\n",
    "\n",
    "LS.train(train_signal, train_symbols)\n",
    "# evaluate the least-squares equalizer itself (this previously re-used LMS.predict)\n",
    "ls_test_output = LS.predict(test_signal)\n",
    "\n",
    "ZF.train(train_signal, train_symbols, channel_function)\n",
    "zf_test_output = ZF.predict(test_signal)\n",
    "\n",
    "\n",
    "# the equalizer outputs are longer than test_bits; num_bit_errs only compares the first data_length entries\n",
    "b_ber = num_bit_errs(test_bits, baseline_bits)\n",
    "print(\"Baseline bit error rate: \", b_ber/data_length)\n",
    "\n",
    "lms_ber = num_bit_errs(test_bits, modulate(lms_test_output, modulation_scheme, demod=True))\n",
    "print(\"Least Mean Squares bit error rate: \", lms_ber/data_length)\n",
    "\n",
    "ls_ber = num_bit_errs(test_bits, modulate(ls_test_output, modulation_scheme, demod=True))\n",
    "print(\"Least Squares bit error rate: \", ls_ber/data_length)\n",
    "\n",
    "zf_ber = num_bit_errs(test_bits, modulate(zf_test_output, modulation_scheme, demod=True))\n",
    "print(\"Zero Forcing bit error rate: \", zf_ber/data_length)\n",
    "\n",
    "print(channel_function)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
